/**
 * 
 */
package edu.umd.clip.lm.tools;

import java.io.*;
import java.nio.charset.Charset;
import java.util.*;
import java.util.regex.*;

import edu.berkeley.nlp.util.*;
import edu.umd.clip.jobs.JobManager;
import edu.umd.clip.lm.util.algo.MDI;
import edu.umd.clip.lm.util.algo.MDIClusterNotifier;
import edu.umd.clip.lm.util.tree.AnyaryTree;

/**
 * Command-line tool that builds a binary hierarchy of word classes from a
 * whitespace-tokenized training text.
 *
 * <p>The tool counts, for every word that may be clustered, how often each
 * other word occurs immediately to its left and to its right, feeds those
 * statistics to {@link MDI}, and lets it recursively split the vocabulary.
 * Every leaf of the resulting tree becomes one class. The output contains one
 * line per word in the form {@code CLASS-nnnnn <probability> <word>}, where
 * the probability is the word's unigram count divided by the total unigram
 * count of its class.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public class ClusterWords {

	/** Command-line options, populated by the option parser. */
	public static class Options {
		@Option(name = "-input", required = false, usage = "Training text file (Default: stdin)")
		public String input;
		@Option(name = "-output", required = false, usage = "Class output file (Default: stdout)")
		public String output;
		@Option(name = "-config", usage = "XML config file")
		public String config;
		@Option(name = "-jobs", usage = "number of concurrent jobs (default: 1)")
		public int jobs = 1;
		@Option(name = "-no-cluster", required = false, usage = "List of words not to cluster")
		public String noCluster;
		@Option(name = "-levels", required = true, usage = "Number of clusters is 2^levels")
		public int levels;
	}

	/** Pseudo-word placed before the first real word of every sentence. */
	private static final String SENT_START = "<s>";
	/** Pseudo-word placed after the last real word of every sentence. */
	private static final String SENT_END = "</s>";

	private static final Pattern wordSplitRE = Pattern.compile("\\s+");

	/** All files (and stdin/stdout) are read and written as UTF-8. */
	private static final Charset UTF8 = Charset.forName("UTF-8");

	/**
	 * Runs the clustering tool.
	 *
	 * @param args command-line arguments, see {@link Options}
	 * @throws IOException if the no-cluster list, the training text or the
	 *         output cannot be read or written; a read failure aborts the run
	 *         instead of clustering a partially read corpus
	 */
	public static void main(String[] args) throws IOException {
		OptionParser optParser = new OptionParser(Options.class);
		final Options opts = (Options) optParser.parse(args, true);

		JobManager.initialize(opts.jobs);
		Thread thread = new Thread(JobManager.getInstance(), "Job Manager");
		thread.setDaemon(true);
		thread.start();

		// Both maps map a word to itself. noClusterDict is used as a set of
		// words excluded from clustering; dictionary doubles as an intern
		// table so that equal words share a single String instance.
		HashMap<String, String> noClusterDict = new HashMap<String, String>();
		HashMap<String, String> dictionary = new HashMap<String, String>(10000);

		HashMap<String, MutableDouble> unigramCounts = new HashMap<String, MutableDouble>();

		// word -> (neighbor word -> count) for the left and right neighbor
		HashMap<String, HashMap<String, MutableDouble>> leftCounts =
				new HashMap<String, HashMap<String, MutableDouble>>(10000);
		HashMap<String, HashMap<String, MutableDouble>> rightCounts =
				new HashMap<String, HashMap<String, MutableDouble>>(10000);

		// Opened before the input is read so that a bad output path fails fast.
		BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
				opts.output == null ? System.out : new FileOutputStream(opts.output), UTF8));
		try {
			if (opts.noCluster != null) {
				readNoClusterList(opts.noCluster, noClusterDict, dictionary);
			}
			// Sentence boundary markers are never clustered.
			noClusterDict.put(SENT_START, SENT_START);
			noClusterDict.put(SENT_END, SENT_END);
			dictionary.put(SENT_START, SENT_START);
			dictionary.put(SENT_END, SENT_END);

			double totalCounts = collectCounts(opts.input, noClusterDict, dictionary,
					unigramCounts, leftCounts, rightCounts);

			// Turn raw neighbor counts into relative frequencies over all tokens.
			scaleCounts(leftCounts, totalCounts);
			scaleCounts(rightCounts, totalCounts);

			if (noClusterDict.size() == dictionary.size()) {
				// Every known word is excluded: there is nothing to cluster.
				System.err.println("noClusterDict.size() == dictionary.size()");
				return;
			}

			// Identity-keyed: a cluster is recognized by the collection instance
			// itself, not by its contents.
			// NOTE(review): this relies on MDI handing back the very same
			// collection instances in later notifications - confirm against MDI.
			final Map<Collection<String>, AnyaryTree<String>> treeNodes =
					new IdentityHashMap<Collection<String>, AnyaryTree<String>>();
			MDIClusterNotifier<String> notifier = new MDIClusterNotifier<String>() {

				/* (non-Javadoc)
				 * @see edu.umd.clip.lm.util.algo.MDIClusterNotifier#notify(java.util.Collection, java.util.Collection, java.util.Collection)
				 */
				public synchronized boolean notify(Collection<String> oldCluster, Collection<String> cluster1, Collection<String> cluster2) {
					AnyaryTree<String> node = treeNodes.get(oldCluster);

					// A singleton cluster is labelled with its only word.
					String label1 = null;
					if (cluster1.size() == 1) {
						label1 = cluster1.iterator().next();
					}
					AnyaryTree<String> subnode1 = new AnyaryTree<String>(label1);
					node.addChild(subnode1);

					String label2 = null;
					if (cluster2.size() == 1) {
						label2 = cluster2.iterator().next();
					}
					AnyaryTree<String> subnode2 = new AnyaryTree<String>(label2);
					node.addChild(subnode2);

					treeNodes.put(cluster1, subnode1);
					treeNodes.put(cluster2, subnode2);

					// level = 1 + number of nodes on the path from the split
					// node up to the root (so the first split reports level 2)
					int level = 1;
					for (AnyaryTree<String> n = node; n != null; n = n.getParent()) {
						++level;
					}

					System.err.printf("[%d] split %d words into %d and %d\n", level, oldCluster.size(), cluster1.size(), cluster2.size());

					return level <= opts.levels;
				}
			};

			HashSet<String> clusterVocab = new HashSet<String>(dictionary.size() - noClusterDict.size());
			for (String word : dictionary.keySet()) {
				if (!noClusterDict.containsKey(word)) {
					clusterVocab.add(word);
				}
			}

			// The whole clusterable vocabulary is the root of the class tree.
			treeNodes.put(clusterVocab, new AnyaryTree<String>(null));

			MDI<String, String> algo = new MDI<String, String>(clusterVocab, dictionary.keySet());
			algo.setNotifier(notifier);

			for (Map.Entry<String, HashMap<String, MutableDouble>> entry : leftCounts.entrySet()) {
				String word = entry.getKey();
				for (Map.Entry<String, MutableDouble> entry2 : entry.getValue().entrySet()) {
					algo.setLeftProb(word, entry2.getKey(), entry2.getValue().doubleValue());
				}
			}
			leftCounts = null; // no longer needed, let it be collected

			for (Map.Entry<String, HashMap<String, MutableDouble>> entry : rightCounts.entrySet()) {
				String word = entry.getKey();
				for (Map.Entry<String, MutableDouble> entry2 : entry.getValue().entrySet()) {
					algo.setRightProb(word, entry2.getKey(), entry2.getValue().doubleValue());
				}
			}
			rightCounts = null; // no longer needed, let it be collected

			algo.normalizeDistributions();
			algo.partition(clusterVocab);

			writeClusters(treeNodes, unigramCounts, writer);
		} finally {
			// Also reached on the early return and on failures, so the output
			// stream is always flushed and released.
			writer.close();
		}
	}

	/**
	 * Reads the list of words that must not be clustered, one word per line.
	 * Lines are trimmed and blank lines are skipped, so the empty string is
	 * never registered as a word.
	 *
	 * @param path file to read
	 * @param noClusterDict receives every listed word
	 * @param dictionary receives every listed word as well
	 * @throws IOException if the file cannot be opened or read
	 */
	private static void readNoClusterList(String path, Map<String, String> noClusterDict,
			Map<String, String> dictionary) throws IOException {
		BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(path), UTF8));
		try {
			for (String line = in.readLine(); line != null; line = in.readLine()) {
				line = line.trim();
				if (line.isEmpty()) {
					continue;
				}
				noClusterDict.put(line, line);
				dictionary.put(line, line);
			}
		} finally {
			in.close();
		}
	}

	/**
	 * Reads the training text (one sentence per line, words separated by
	 * whitespace) and accumulates unigram and neighbor counts.
	 *
	 * @param inputPath training file, or {@code null} to read standard input
	 * @param noClusterDict words for which no neighbor counts are collected
	 * @param dictionary intern table; every new word is added to it
	 * @param unigramCounts receives the number of occurrences of every word
	 * @param leftCounts receives, per clusterable word, counts of the preceding word
	 * @param rightCounts receives, per clusterable word, counts of the following word
	 * @return total number of word tokens read (sentence markers not included)
	 * @throws IOException if the input cannot be opened or read
	 */
	private static double collectCounts(String inputPath, Map<String, String> noClusterDict,
			Map<String, String> dictionary, Map<String, MutableDouble> unigramCounts,
			Map<String, HashMap<String, MutableDouble>> leftCounts,
			Map<String, HashMap<String, MutableDouble>> rightCounts) throws IOException {
		double totalCounts = 0.0;
		BufferedReader in = new BufferedReader(new InputStreamReader(
				inputPath == null ? System.in : new FileInputStream(inputPath), UTF8));
		try {
			for (String line = in.readLine(); line != null; line = in.readLine()) {
				line = line.trim();
				if (line.isEmpty()) {
					continue;
				}

				String[] realWords = wordSplitRE.split(line);

				// The sentence padded with the boundary markers on both sides.
				String[] sentence = new String[realWords.length + 2];
				sentence[0] = SENT_START;
				sentence[sentence.length - 1] = SENT_END;

				for (int i = 0; i < realWords.length; ++i) {
					String word = dictionary.get(realWords[i]);
					if (word == null) {
						word = realWords[i];
						dictionary.put(word, word);
					}
					increment(unigramCounts, word);
					sentence[i + 1] = word;
					totalCounts += 1;
				}

				for (int i = 1; i < sentence.length - 1; ++i) {
					String word = sentence[i];
					if (noClusterDict.containsKey(word)) {
						continue;
					}
					increment(neighborMap(leftCounts, word), sentence[i - 1]);
					increment(neighborMap(rightCounts, word), sentence[i + 1]);
				}
			}
		} finally {
			in.close();
		}
		return totalCounts;
	}

	/** Returns the neighbor-count map of {@code word}, creating it on first use. */
	private static HashMap<String, MutableDouble> neighborMap(
			Map<String, HashMap<String, MutableDouble>> counts, String word) {
		HashMap<String, MutableDouble> map = counts.get(word);
		if (map == null) {
			map = new HashMap<String, MutableDouble>();
			counts.put(word, map);
		}
		return map;
	}

	/** Adds one to the count stored under {@code key}, starting at one if absent. */
	private static void increment(Map<String, MutableDouble> counts, String key) {
		MutableDouble count = counts.get(key);
		if (count == null) {
			counts.put(key, new MutableDouble(1.0));
		} else {
			count.add(1.0);
		}
	}

	/** Divides every stored count by {@code total}, in place. */
	private static void scaleCounts(Map<String, HashMap<String, MutableDouble>> counts, double total) {
		for (HashMap<String, MutableDouble> neighbors : counts.values()) {
			for (MutableDouble count : neighbors.values()) {
				count.set(count.doubleValue() / total);
			}
		}
	}

	/**
	 * Writes one line per word for every leaf of the class tree.
	 *
	 * <p>Classes are numbered in the iteration order of {@code treeNodes}; for
	 * an {@link IdentityHashMap} that order is unspecified. Numbers are
	 * formatted with {@link Locale#ROOT} so that the output always uses ASCII
	 * digits and a decimal point, whatever the default locale is.
	 *
	 * @param treeNodes cluster (collection of words) to tree node mapping
	 * @param unigramCounts occurrence count of every word
	 * @param writer destination; not closed by this method
	 * @throws IOException if writing fails
	 */
	private static void writeClusters(Map<Collection<String>, AnyaryTree<String>> treeNodes,
			Map<String, MutableDouble> unigramCounts, Writer writer) throws IOException {
		int nrClusters = 0;
		for (Map.Entry<Collection<String>, AnyaryTree<String>> e : treeNodes.entrySet()) {
			if (!e.getValue().isLeaf()) {
				continue;
			}
			Collection<String> words = e.getKey();
			++nrClusters;
			String clusterName = String.format(Locale.ROOT, "CLASS-%05d", nrClusters);

			double totalClusterCount = 0;
			for (String word : words) {
				totalClusterCount += unigramCounts.get(word).doubleValue();
			}
			for (String word : words) {
				double prob = unigramCounts.get(word).doubleValue() / totalClusterCount;
				writer.append(String.format(Locale.ROOT, "%s %f %s\n", clusterName, prob, word));
			}
		}
	}

}
